---
title: "EpiClass accurately predicts EpiATLAS assay and biospecimen metadata"
author: "Joanny Raby"
resources:
- "../resources/threshold_impact_graph_global_assay_epiclass_11c-11_classes_cdn.html"
- "../resources/threshold_impact_graph_global_harmonized_sample_ontology_intermediate-16_classes_cdn.html"
format:
html:
code-fold: true
code-tools: true
toc: true
toc-location: right
toc-expand: 2
embed-resources: true
engine: jupyter
execute:
echo: true
warning: false
eval: true
error: false
---
# Results section 1 figures
The formatting of the figures may differ slightly from those in the paper, but they display the same data points.
All code cells are folded by default. To view any cell, click **"Code"** to expand it, or use the code options near the main title above to unfold all at once.
Some code may be repeated, as the original Python notebook was designed for figures to be generated semi-independently.
## Setup Code - Imports and co.
Setup imports.
```{python}
#| label: setup-imports
from __future__ import annotations
import copy
import logging
import re
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Image, display
from plotly.subplots import make_subplots
from sklearn.metrics import auc, confusion_matrix as sk_cm, roc_curve
from sklearn.preprocessing import label_binarize
from epiclass.core.confusion_matrix import ConfusionMatrixWriter
from epiclass.utils.notebooks.paper.metrics_per_assay import MetricsPerAssay
from epiclass.utils.notebooks.paper.paper_utilities import (
ASSAY,
ASSAY_MERGE_DICT,
ASSAY_ORDER,
CELL_TYPE,
SEX,
IHECColorMap,
MetadataHandler,
SplitResultsHandler,
extract_input_sizes_from_output_files,
merge_similar_assays,
PathChecker
)
```
Setup paths.
```{python}
#| label: setup-paths
# Root path
base_dir = Path.home() / "Projects/epiclass/output/paper"
PathChecker.check_directory(base_dir)
# More precise
base_data_dir = base_dir / "data"
base_fig_dir = base_dir / "figures"
# alias
paper_dir = base_dir
```
Setup colors.
```{python}
#| label: setup-colors
IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map
```
Setup metadata and prediction files handlers.
```{python}
#| label: setup-handlers
split_results_handler = SplitResultsHandler()
metadata_handler = MetadataHandler(paper_dir)
metadata_v2 = metadata_handler.load_metadata("v2")
metadata_v2_df = metadata_v2.to_df()
```
Setup data directories.
```{python}
#| label: setup-data-dirs
gen_data_dir = base_data_dir / "training_results" / "dfreeze_v2"
PathChecker.check_directory(gen_data_dir)
data_dir_100kb = gen_data_dir / "hg38_100kb_all_none"
PathChecker.check_directory(data_dir_100kb)
```
Setup figures general settings.
```{python}
#| label: setup-figs-settings
main_title_settings = {
"title":dict(
automargin=True,
x=0.5,
xanchor="center",
yanchor="top",
y=0.98
),
"margin":dict(t=50, l=10, r=10)
}
```
## Figure 1
Performance of EpiClass Assay and Biospecimen classifiers.
### A - EpiClass training overview
Fig. 1A: Overview of the EpiClass training process for various classifiers and their inference on external data. Each classifier is trained independently.
### B-C Prep
Path setup.
```{python}
#| label: fig1bc-prep-paths
mixed_data_dir = gen_data_dir / "mixed"
PathChecker.check_directory(mixed_data_dir)
```
Feature sets setup.
```{python}
#| label: fig1bc-prep-feature-sets
feature_sets_14 = [
"hg38_10mb_all_none_1mb_coord",
"hg38_100kb_random_n316_none",
"hg38_1mb_all_none",
"hg38_100kb_random_n3044_none",
"hg38_100kb_all_none",
"hg38_gene_regions_100kb_coord_n19864",
"hg38_10kb_random_n30321_none",
"hg38_regulReg_allCorr_n30k",
"hg38_1kb_random_n30321_none",
"hg38_cpg_topvar_200bp_n30k_wrong_coordinates",
"hg38_10kb_all_none",
"hg38_regulReg_allCorr_n303k",
"hg38_1kb_random_n303114_none",
"hg38_cpg_topvar_200bp_n303k_wrong_coordinates",
]
fig1_sets = [
"hg38_10mb_all_none_1mb_coord",
"hg38_100kb_random_n316_none",
"hg38_1mb_all_none",
"hg38_100kb_random_n3044_none",
"hg38_100kb_all_none",
"hg38_10kb_random_n30321_none",
"hg38_1kb_random_n30321_none",
"hg38_10kb_all_none",
"hg38_1kb_random_n303114_none",
]
metric_orders_map = {
"fig1_sets": fig1_sets,
"feature_sets_14": feature_sets_14,
}
```
Compute input sizes for each feature set.
```{python}
#| label: fig1bc-prep-input-sizes
input_sizes = extract_input_sizes_from_output_files(mixed_data_dir) # type: ignore
input_sizes: Dict[str, int] = {k: v.pop() for k, v in input_sizes.items() if len(v) == 1} # type: ignore
```
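The comprehension above keeps only the feature sets for which exactly one input size was found. A minimal sketch of that filtering step, assuming the helper returns a mapping from feature set name to a set of observed sizes (the names and values below are hypothetical):

```python
from typing import Dict, Set

# Hypothetical stand-in for the helper's output: feature set -> observed input sizes.
raw_sizes: Dict[str, Set[int]] = {
    "set_a": {30_000},
    "set_b": {3_000, 3_001},  # ambiguous: two different sizes were seen
    "set_c": set(),  # nothing found
}

# Keep only unambiguous entries and unwrap the single value.
# next(iter(...)) reads the value without mutating the original set, unlike pop().
unique_sizes: Dict[str, int] = {
    name: next(iter(sizes)) for name, sizes in raw_sizes.items() if len(sizes) == 1
}

print(unique_sizes)
```

Feature sets dropped at this step are the ones later reported as "no input size found" by the graphing function.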
Set selection.
```{python}
#| label: fig1bc-prep-set-selection
set_selection_name = "feature_sets_14"
logdir = (
base_fig_dir
/ "fig2_EpiAtlas_other"
/ "fig2--reduced_feature_sets"
/ "test"
/ set_selection_name
)
logdir.mkdir(parents=True, exist_ok=True)
```
Compute metrics.
```{python}
#| label: fig1bc-prep-metrics
all_metrics = split_results_handler.obtain_all_feature_set_data(
parent_folder=mixed_data_dir,
merge_assays=True,
return_type="metrics",
include_categories=[ASSAY, CELL_TYPE],
include_sets=metric_orders_map[set_selection_name],
exclude_names=["16ct", "27ct", "7c", "chip-seq-only"],
)
# Order the metrics
all_metrics = {
name: all_metrics[name] # type: ignore
for name in metric_orders_map[set_selection_name]
if name in all_metrics
}
```
Label correction.
```{python}
#| label: fig1bc-prep-correct-name
# correct a name
try:
all_metrics["hg38_100kb_all_none"][ASSAY] = all_metrics["hg38_100kb_all_none"][ # type: ignore
f"{ASSAY}_11c"
]
del all_metrics["hg38_100kb_all_none"][f"{ASSAY}_11c"]
except KeyError:
pass
```
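The cell above moves the `{ASSAY}_11c` entry to the standard assay key and ignores the case where it is absent. An equivalent sketch using `dict.pop`, with a hypothetical stand-in for the `ASSAY` constant and toy metric values:

```python
ASSAY_KEY = "assay"  # hypothetical stand-in for the ASSAY constant

# Toy metrics for a single feature set: task name -> split -> metrics.
tasks = {
    f"{ASSAY_KEY}_11c": {"split0": {"Accuracy": 0.99}},
    "cell_type": {"split0": {"Accuracy": 0.95}},
}

# Move the entry to the standard key only when the non-standard key is present.
old_key = f"{ASSAY_KEY}_11c"
if old_key in tasks:
    tasks[ASSAY_KEY] = tasks.pop(old_key)

print(sorted(tasks))
```

Checking for the key first avoids a bare `except KeyError: pass`, which would also hide a missing feature set.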
Resolution/feature set --> color mapping.
```{python}
#| label: fig1bc-prep-resolution-colors
resolution_colors = {
"100kb": px.colors.qualitative.Safe[0],
"10kb": px.colors.qualitative.Safe[1],
"1kb": px.colors.qualitative.Safe[2],
"regulReg": px.colors.qualitative.Safe[3],
"gene": px.colors.qualitative.Safe[4],
"cpg": px.colors.qualitative.Safe[5],
"1mb": px.colors.qualitative.Safe[6],
"5mb": px.colors.qualitative.Safe[7],
"10mb": px.colors.qualitative.Safe[8],
}
```
Define graphing function `graph_feature_set_metrics`.
```{python}
#| label: fig1bc-prep-graph-function
def graph_feature_set_metrics(
all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
input_sizes: Dict[str, int],
logdir: Path | None = None,
sort_by_input_size: bool = False,
name: str | None = None,
y_range: Tuple[float, float] | None = None,
boxpoints: str = "all",
width: int = 1200,
height: int = 1200,
) -> None:
"""Graph the metrics for all feature sets.
    Args:
        all_metrics (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): Metrics for all feature sets.
            Format: {feature_set: {task_name: {split_name: metric_dict}}}
        input_sizes (Dict[str, int]): Input size (number of features) of each feature set.
        logdir (Path|None): Directory where the figure is saved. If None, the figure is only displayed.
        sort_by_input_size (bool): Whether to sort the feature sets by input size.
        name (str|None): Suffix appended to the figure title and file names.
        y_range (Tuple[float, float]|None): The y-axis range for the figure.
        boxpoints (str): Which points to draw, "all" or "outliers". Defaults to "all".
        width (int): Figure width in pixels.
        height (int): Figure height in pixels.
"""
if boxpoints not in ["all", "outliers"]:
raise ValueError("Invalid boxpoints value.")
reference_hdf5_type = "hg38_100kb_all_none"
metadata_categories = list(all_metrics[reference_hdf5_type].keys())
non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}
non_standard_assay_task_names = ["hg38_100kb_all_none"]
used_resolutions = set()
for i in range(len(metadata_categories)):
category_idx = i
category_fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=True,
subplot_titles=["Accuracy", "F1-score (macro)"],
horizontal_spacing=0.01,
)
trace_names = []
order = list(all_metrics.keys())
if sort_by_input_size:
order = sorted(
all_metrics.keys(),
key=lambda x: input_sizes[x],
)
for feature_set_name in order:
# print(feature_set_name)
tasks_dicts = all_metrics[feature_set_name]
meta_categories = copy.deepcopy(metadata_categories)
if feature_set_name not in input_sizes:
print(f"Skipping {feature_set_name}, no input size found.")
continue
task_name = meta_categories[category_idx]
if "split" in task_name:
raise ValueError("Split in task name. Wrong metrics dict.")
try:
task_dict = tasks_dicts[task_name]
except KeyError as err:
                if SEX in str(err) and non_standard_names[SEX] in tasks_dicts:
                    # Fall back to the non-standard sex task name when it is present.
                    task_dict = tasks_dicts[non_standard_names[SEX]]
elif (
ASSAY in str(err)
and feature_set_name in non_standard_assay_task_names
):
task_dict = tasks_dicts[non_standard_names[ASSAY]]
else:
print("Skipping", feature_set_name, task_name)
continue
input_size = input_sizes[feature_set_name]
feature_set_name = feature_set_name.replace("_none", "").replace("hg38_", "")
feature_set_name = re.sub(r"\_[\dmkb]+\_coord", "", feature_set_name)
resolution = feature_set_name.split("_")[0]
used_resolutions.add(resolution)
trace_name = f"{input_size}|{feature_set_name}"
trace_names.append(trace_name)
# Accuracy
metric = "Accuracy"
y_vals = [task_dict[split][metric] for split in task_dict]
hovertext = [
f"{split}: {metrics_dict[metric]:.4f}"
for split, metrics_dict in task_dict.items()
]
category_fig.add_trace(
go.Box(
y=y_vals,
name=trace_name,
boxmean=True,
boxpoints=boxpoints,
marker=dict(size=3, color="black"),
line=dict(width=1, color="black"),
fillcolor=resolution_colors[resolution],
hovertemplate="%{text}",
text=hovertext,
legendgroup=resolution,
showlegend=False,
),
row=1,
col=1,
)
metric = "F1_macro"
y_vals = [task_dict[split][metric] for split in task_dict]
hovertext = [
f"{split}: {metrics_dict[metric]:.4f}"
for split, metrics_dict in task_dict.items()
]
category_fig.add_trace(
go.Box(
y=y_vals,
name=trace_name,
boxmean=True,
boxpoints=boxpoints,
marker=dict(size=3, color="black"),
line=dict(width=1, color="black"),
fillcolor=resolution_colors[resolution],
hovertemplate="%{text}",
text=hovertext,
legendgroup=resolution,
showlegend=False,
),
row=1,
col=2,
)
title = f"{metadata_categories[category_idx]} classification"
title = title.replace(CELL_TYPE, "biospecimen")
if name is not None:
title += f" - {name}"
category_fig.update_layout(
width=width,
height=height,
title_text=title,
**main_title_settings
)
# dummy scatters for resolution colors
for resolution, color in resolution_colors.items():
if resolution not in used_resolutions:
continue
category_fig.add_trace(
go.Scatter(
x=[None],
y=[None],
mode="markers",
name=resolution,
marker=dict(color=color, size=5),
showlegend=True,
legendgroup=resolution,
)
)
category_fig.update_layout(legend=dict(itemsizing="constant"))
# y-axis
if y_range:
category_fig.update_yaxes(range=y_range)
else:
if ASSAY in task_name:
category_fig.update_yaxes(range=[0.96, 1.001])
if CELL_TYPE in task_name:
category_fig.update_yaxes(range=[0.75, 1])
category_fig.update_layout(**main_title_settings)
# Save figure
if logdir:
base_name = f"feature_set_metrics_{metadata_categories[category_idx]}"
if name is not None:
base_name = base_name + f"_{name}"
category_fig.write_html(logdir / f"{base_name}.html")
category_fig.write_image(logdir / f"{base_name}.svg")
category_fig.write_image(logdir / f"{base_name}.png")
category_fig.show()
```
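Trace labels and box colors in `graph_feature_set_metrics` both derive from the cleaned feature set name. A small sketch of that cleaning step in isolation (the helper name `clean_feature_set_name` is hypothetical; the string operations mirror those in the function above):

```python
import re


def clean_feature_set_name(raw_name: str) -> str:
    """Strip the 'hg38_' prefix, the '_none' tag and any '_<size>_coord' suffix."""
    name = raw_name.replace("_none", "").replace("hg38_", "")
    return re.sub(r"_[\dmkb]+_coord", "", name)


for raw in ("hg38_100kb_all_none", "hg38_10mb_all_none_1mb_coord"):
    cleaned = clean_feature_set_name(raw)
    resolution = cleaned.split("_")[0]  # first token selects the box color
    print(f"{raw} -> {cleaned} (resolution: {resolution})")
```

The first token of the cleaned name must be a key of `resolution_colors`, otherwise the color lookup raises a `KeyError`.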
### B - Assay MLP performance
Graph 100kb resolution MLP metrics.
```{python}
#| label: fig1b
#| layout-nrow: 1
#| column: page-left
metrics_fig1b = {name: all_metrics[name] for name in ["hg38_100kb_all_none"]}
metrics_fig1b_1 = {
"hg38_100kb_all_none": {ASSAY: metrics_fig1b["hg38_100kb_all_none"][ASSAY]}
}
graph_feature_set_metrics(
all_metrics=metrics_fig1b_1, # type: ignore
input_sizes=input_sizes,
boxpoints="all",
width=425,
height=400,
y_range=(0.98, 1.001),
)
metrics_fig1b_2 = {
"hg38_100kb_all_none": {CELL_TYPE: metrics_fig1b["hg38_100kb_all_none"][CELL_TYPE]}
}
graph_feature_set_metrics(
all_metrics=metrics_fig1b_2, # type: ignore
input_sizes=input_sizes,
boxpoints="all",
width=425,
height=400,
y_range=(0.93, 1.001),
)
```
Fig. 1B: Distribution of accuracy and F1-score for each of the ten training folds (dots) for the Assay and Biospecimen MLP classifiers.
### C - MLP performance at varying resolution
Graph.
```{python}
#| label: fig1c
#| layout-nrow: 2
#| column: page-left
metrics_fig1c = {name: all_metrics[name] for name in fig1_sets}
graph_feature_set_metrics(
all_metrics=metrics_fig1c, # type: ignore
input_sizes=input_sizes,
boxpoints="all",
width=900,
height=600,
)
```
Fig. 1C-alt: Distribution of accuracy per training fold for different bin resolutions for the Assay and Biospecimen classifiers.\
\
\
Define function `parse_bin_size` to extract a numerical bin size in base pairs.
```{python}
#| label: fig1c-parse-bin-size
def parse_bin_size(feature_set_name: str) -> Optional[float]:
"""
Parses the feature set name to extract a numerical bin size in base pairs.
Handles formats like '100kb', '5mb', 'regulReg', 'gene', 'cpg'.
Returns numerical size (float) or None if unparseable or non-numeric.
Assigns placeholder values for non-genomic-range types if needed,
but for a continuous axis, it's better to return None or filter later.
"""
name_parts = feature_set_name.replace("hg38_", "").split("_")
if not name_parts:
return None
resolution_str = name_parts[0].lower()
# Handle standard genomic ranges
match_kb = re.match(r"(\d+)kb", resolution_str)
if match_kb:
return float(match_kb.group(1)) * 1_000
match_mb = re.match(r"(\d+)mb", resolution_str)
if match_mb:
return float(match_mb.group(1)) * 1_000_000
# Handle non-range types - decide how to represent them.
# Option 1: Return None (they won't be plotted on the numeric axis)
# Option 2: Assign arbitrary numbers (might distort scale)
# Option 3: Could use different marker symbols later if needed
    # resolution_str is lowercased, so "regulReg" arrives here as "regulreg"
    if resolution_str in ["regulreg", "regulatory", "gene", "cpg"]:
        return None  # Returning None is cleaner for a pure numeric axis
# Fallback for unrecognised formats
try:
# Maybe it's just a number (e.g., representing window size)?
return float(resolution_str)
except ValueError:
return None
```
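A condensed, self-contained restatement of the parsing rule can help check which names end up on the numeric axis. The helper name `bin_size_bp` is hypothetical, the fallback branches are omitted, and `re.fullmatch` is used here, which is stricter than the `re.match` calls above:

```python
import re
from typing import Optional


def bin_size_bp(feature_set_name: str) -> Optional[float]:
    """Return the bin size in base pairs, or None when the name has no kb/mb prefix."""
    resolution = feature_set_name.replace("hg38_", "").split("_")[0].lower()
    match = re.fullmatch(r"(\d+)(kb|mb)", resolution)
    if match is None:
        return None
    scale = {"kb": 1_000, "mb": 1_000_000}[match.group(2)]
    return float(match.group(1)) * scale


print(bin_size_bp("hg38_100kb_all_none"))
print(bin_size_bp("hg38_10mb_all_none_1mb_coord"))
print(bin_size_bp("hg38_regulReg_allCorr_n30k"))
```

Only the first token after the `hg38_` prefix is inspected, so a trailing `_1mb_coord` suffix does not affect the result.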
Define graphing function `graph_feature_set_scatter` to plot performance metrics as a scatter plot instead of a box plot.
```{python}
#| label: fig1c-scatter-function
def graph_feature_set_scatter(
all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
input_sizes: Dict[str, int],
logdir: Optional[Path] = None,
metric_to_plot: str = "Accuracy",
name: Optional[str] = None,
metric_range: Optional[Tuple[float, float]] = None,
assay_task_key: str = ASSAY,
sex_task_key: str = SEX,
cell_type_task_key: str = CELL_TYPE,
verbose: bool = True,
) -> None:
"""
Graphs performance metrics as a scatter plot with modifications.
X-axis: Number of Features (log scale).
Y-axis: Average performance metric (e.g., Accuracy, F1_macro) across folds.
Vertical lines indicate the min/max range across folds.
Color: Bin Size (bp, log scale).
Args:
all_metrics: Nested dict {feature_set: {task_name: {split_name: metric_dict}}}.
input_sizes: Dict {feature_set: num_features}.
logdir: Directory to save figures. If None, display only.
metric_to_plot: The metric key to use for the Y-axis ('Accuracy', 'F1_macro').
name: Optional suffix for figure titles and filenames.
metric_range: Optional tuple (min, max) to set the Y-axis range.
assay_task_key: Key used for the assay prediction task.
sex_task_key: Key used for the sex prediction task.
        cell_type_task_key: Key used for the cell type prediction task.
        verbose: Whether to print the list of tasks being plotted.
"""
if metric_to_plot not in ["Accuracy", "F1_macro"]:
raise ValueError("metric_to_plot must be 'Accuracy' or 'F1_macro'")
# --- Standard Name Handling (simplified from original) ---
non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}
# --- Find reference and task names ----
reference_hdf5_type = next(iter(all_metrics), None)
if reference_hdf5_type is None or not all_metrics.get(reference_hdf5_type):
print(
"Warning: Could not determine tasks from all_metrics. Trying default tasks."
)
cleaned_metadata_categories = {assay_task_key, sex_task_key, cell_type_task_key}
else:
metadata_categories = list(all_metrics[reference_hdf5_type].keys())
cleaned_metadata_categories = set()
for cat in metadata_categories:
original_name = cat
for standard, non_standard in non_standard_names.items():
if cat == non_standard:
original_name = standard
break
cleaned_metadata_categories.add(original_name)
# --- Define Bin size categories and Colors ---
bin_category_names = ["1Kb", "10Kb", "100Kb", "1Mb", "10Mb"]
bin_category_values = [1000, 10000, 100 * 1000, 1000 * 1000, 10000 * 1000]
discrete_colors = px.colors.sequential.Viridis_r
color_map = {
name: discrete_colors[i * 2] for i, name in enumerate(bin_category_names)
}
if verbose:
print(f"Plotting for tasks: {list(cleaned_metadata_categories)}")
for category_name in cleaned_metadata_categories:
plot_data_points = []
for feature_set_name_orig in all_metrics.keys():
try:
num_features = input_sizes[feature_set_name_orig]
except KeyError as e:
raise ValueError(
f"Feature set '{feature_set_name_orig}' not found in input_sizes"
) from e
# Parse Bin Size
bin_size = parse_bin_size(feature_set_name_orig)
if bin_size is None:
print(
f"Skipping {feature_set_name_orig}, could not parse numeric bin size."
)
continue
# 3. Get Metric Values (Average, Min, Max)
tasks_dicts = all_metrics[feature_set_name_orig]
# --- Task Name Lookup ---
# 1. Try the standard category name first
# 2. If standard name not found, use non-standard name
task_dict = None
task_name = category_name
if category_name in tasks_dicts:
task_dict = tasks_dicts[category_name]
else:
non_standard_task_name = non_standard_names.get(category_name)
if non_standard_task_name and non_standard_task_name in tasks_dicts:
task_name = non_standard_task_name
task_dict = tasks_dicts[non_standard_task_name]
if task_dict is None:
raise ValueError(
f"Task '{category_name}' not found in feature set '{feature_set_name_orig}'"
)
# --- End Task Name Lookup ---
# Calculate average, min, max metric value across splits
try:
metric_values = []
for split, split_data in task_dict.items():
if metric_to_plot in split_data:
metric_values.append(split_data[metric_to_plot])
else:
print(
f"Warning: Metric '{metric_to_plot}' not found in split '{split}' for {feature_set_name_orig} / {task_name}"
)
if not metric_values:
print(
f"Warning: No metric values found for {feature_set_name_orig} / {task_name} / {metric_to_plot}"
)
continue
avg_metric = np.mean(metric_values)
min_metric = np.min(metric_values)
max_metric = np.max(metric_values)
except Exception as e: # pylint: disable=broad-except
raise ValueError(
f"Error calculating metrics for {feature_set_name_orig} / {task_name}: {e}"
) from e
# Clean feature set name for hover text
clean_name = feature_set_name_orig.replace("_none", "").replace("hg38_", "")
clean_name = re.sub(r"\_[\dmkb]+\_coord", "", clean_name)
# Store data for this point
plot_data_points.append(
{
"bin_size": bin_size,
"num_features": num_features,
"metric_value": avg_metric,
"min_metric": min_metric, # For error bar low
"max_metric": max_metric, # For error bar high
"name": clean_name,
"raw_name": feature_set_name_orig,
}
)
if not plot_data_points:
raise ValueError(
f"No suitable data points found to plot for task: {category_name}"
)
# --- Determine Marker Symbols ---
marker_symbols = []
default_symbol = "circle"
random_symbol = "cross"
for p in plot_data_points:
if "random" in p["raw_name"]:
marker_symbols.append(random_symbol)
else:
marker_symbols.append(default_symbol)
# --- Group Data by Category ---
points_by_category = {name: [] for name in bin_category_names}
for i, point_data in enumerate(plot_data_points):
bin_size = point_data["bin_size"]
assigned_category = None
for cat_name, cat_value in zip(bin_category_names, bin_category_values):
if bin_size == cat_value:
assigned_category = cat_name
break
else:
raise ValueError(f"Could not find category for bin size: {bin_size}")
points_by_category[assigned_category].append(
{
"x": point_data["num_features"], # X is Num Features
"y": point_data["metric_value"],
"error_up": point_data["max_metric"] - point_data["metric_value"],
"error_down": point_data["metric_value"] - point_data["min_metric"],
"text": point_data["name"],
"customdata": [
point_data["min_metric"],
point_data["max_metric"],
point_data["bin_size"],
], # Keep bin size for hover
"symbol": marker_symbols[i], # Assign symbol determined earlier
}
)
# --- Create Figure and Add Traces PER CATEGORY ---
fig = go.Figure()
traces = []
for cat_name in bin_category_names: # Iterate in defined order for legend
points_in_cat = points_by_category[cat_name]
if not points_in_cat:
continue
category_color = color_map[cat_name]
# Extract data for all points in this category
x_vals = [p["x"] for p in points_in_cat]
y_vals = [p["y"] for p in points_in_cat]
error_up_vals = [p["error_up"] for p in points_in_cat]
error_down_vals = [p["error_down"] for p in points_in_cat]
text_vals = [p["text"] for p in points_in_cat]
customdata_vals = [p["customdata"] for p in points_in_cat]
symbols_vals = [p["symbol"] for p in points_in_cat]
trace = go.Scatter(
x=x_vals,
y=y_vals,
mode="markers",
name=cat_name,
showlegend=False,
legendgroup=cat_name, # Group legend entries
marker=dict(
color=category_color,
size=15,
symbol=symbols_vals,
line=dict(width=1, color="DarkSlateGrey"),
),
error_y=dict(
type="data",
symmetric=False,
array=error_up_vals,
arrayminus=error_down_vals,
visible=True,
thickness=1.5,
width=15,
color=category_color,
),
text=text_vals,
customdata=customdata_vals,
                hovertemplate=(
                    f"<b>%{{text}}</b><br><br>"
                    f"Num Features: %{{x:,.0f}}<br>"
                    f"{metric_to_plot}: %{{y:.4f}}<br>"
                    # customdata holds [min_metric, max_metric, bin_size] for each point
                    f"Bin Size: %{{customdata[2]:,.0f}} bp<br>"
                    f"{metric_to_plot} Range (10-fold): %{{customdata[0]:.4f}} - %{{customdata[1]:.4f}}"
                    "<extra></extra>"
                ),
)
traces.append(trace)
fig.add_traces(traces)
# --- Add Legend ---
# Add a hidden scatter trace with square markers for legend
for cat_name in bin_category_names:
category_color = color_map[cat_name]
legend_trace = go.Scatter(
x=[None],
y=[None],
mode="markers",
name=cat_name,
marker=dict(
color=category_color,
size=15,
symbol="square",
line=dict(width=1, color="DarkSlateGrey"),
),
legendgroup=cat_name,
showlegend=True,
)
fig.add_trace(legend_trace)
# --- Update layout ---
title_name = category_name.replace(CELL_TYPE, "biospecimen")
plot_title = f"{metric_to_plot} vs Number of Features - {title_name}"
if name:
plot_title += f" - {name}"
xaxis_title = "Number of Features (log scale)"
xaxis_type = "log"
yaxis_title = metric_to_plot.replace("_", " ").title()
yaxis_type = "linear"
fig.update_layout(
xaxis_title=xaxis_title,
yaxis_title=yaxis_title,
xaxis_type=xaxis_type,
yaxis_type=yaxis_type,
yaxis_range=metric_range,
width=500,
height=500,
hovermode="closest",
legend_title_text="Bin Size",
title_text=plot_title,
**main_title_settings
)
        if metric_range is None:
            # Default ranges apply only when the caller did not set one.
            if category_name == CELL_TYPE:
                fig.update_yaxes(range=[0.75, 1.005])
            elif category_name == ASSAY:
                fig.update_yaxes(range=[0.96, 1.001])
# --- Save or show figure ---
if logdir:
logdir.mkdir(parents=True, exist_ok=True)
# Include "modified" or similar in filename to distinguish
base_name = f"feature_scatter_MODIFIED_v2_{category_name}_{metric_to_plot}"
if name:
base_name += f"_{name}"
html_path = logdir / f"{base_name}.html"
svg_path = logdir / f"{base_name}.svg"
png_path = logdir / f"{base_name}.png"
print(f"Saving modified plot for {category_name} to {html_path}")
fig.write_html(html_path)
fig.write_image(svg_path)
fig.write_image(png_path)
fig.show()
```
Graph
```{python}
#| label: fig1c-scatter-plot
#| column: screen-inset-left
#| layout-ncol: 2
#| layout-nrow: 2
for metric in ["Accuracy", "F1_macro"]:
graph_feature_set_scatter(
all_metrics=metrics_fig1c, # type: ignore
input_sizes=input_sizes,
metric_to_plot=metric,
verbose=False,
)
```
Fig. 1C: Distribution of accuracy per training fold for different bin resolutions for the Assay and Biospecimen classifiers. The circles represent the means and the whiskers the min and max values of the ten training folds.
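Each marker and its whiskers come from three numbers per feature set: the mean, the minimum and the maximum over the folds. A minimal sketch of how the asymmetric whisker lengths are derived, using hypothetical fold accuracies:

```python
from statistics import mean

# Hypothetical accuracies for the ten folds of one feature set.
fold_accuracies = [0.95, 0.96, 0.97, 0.94, 0.96, 0.95, 0.98, 0.96, 0.97, 0.96]

center = mean(fold_accuracies)
error_up = max(fold_accuracies) - center  # upper whisker length
error_down = center - min(fold_accuracies)  # lower whisker length

print(f"mean={center:.3f}, +{error_up:.3f} / -{error_down:.3f}")
```

These two lengths correspond to the `array` and `arrayminus` entries of `error_y` in the scatter function, with `symmetric=False`.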
### D - Confusion matrix
Define function `create_confusion_matrix` to create and show a confusion matrix.
```{python}
#| label: fig1d-cm-function
def create_confusion_matrix(
df: pd.DataFrame,
name: str = "confusion_matrix",
logdir: Path | None = None,
min_pred_score: float = 0,
majority: bool = False,
verbose:bool=False
) -> None:
"""Create a confusion matrix for the given DataFrame and save it to the logdir.
Args:
df (pd.DataFrame): The DataFrame containing the results.
logdir (Path): The directory path for saving the figures.
name (str): The name for the saved figures.
min_pred_score (float): The minimum prediction score to consider.
majority (bool): Whether to use majority vote (uuid-wise) for the predicted class.
"""
# Compute confusion matrix
classes = sorted(df["True class"].unique())
if "Max pred" not in df.columns:
df["Max pred"] = df[classes].max(axis=1) # type: ignore
filtered_df = df[df["Max pred"] > min_pred_score]
if majority:
# Majority vote for predicted class
groupby_uuid = filtered_df.groupby(["uuid", "True class", "Predicted class"])[
"Max pred"
].aggregate(["size", "mean"])
if groupby_uuid["size"].max() > 3:
raise ValueError("More than three predictions for the same uuid.")
groupby_uuid = groupby_uuid.reset_index().sort_values(
["uuid", "True class", "size"], ascending=[True, True, False]
)
groupby_uuid = groupby_uuid.drop_duplicates(
subset=["uuid", "True class"], keep="first"
)
filtered_df = groupby_uuid
confusion_mat = sk_cm(
filtered_df["True class"], filtered_df["Predicted class"], labels=classes
)
mat_writer = ConfusionMatrixWriter(labels=classes, confusion_matrix=confusion_mat)
if logdir is None:
logdir = Path(tempfile.gettempdir())
files = mat_writer.to_all_formats(logdir, name=f"{name}_n{len(filtered_df)}")
if verbose:
print(f"Saved confusion matrix to {logdir}:")
for file in files:
print(Path(file).name)
    for file in files:
        # Accept both str and Path entries
        file_path = Path(file)
        if file_path.suffix == ".png":
            scale = 0.6
            display(
                Image(filename=str(file_path), width=int(1250 * scale), height=int(1000 * scale))
            )
```
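With `majority=True`, the function keeps, for each uuid, the predicted class supported by the most files. A pure-Python sketch of that vote on hypothetical records (the function itself uses a pandas `groupby`; tie-breaking here follows first-seen order and may differ from the pandas version):

```python
from collections import Counter, defaultdict

# Hypothetical per-file predictions: (uuid, predicted class).
file_predictions = [
    ("uuid1", "h3k27ac"),
    ("uuid1", "h3k27ac"),
    ("uuid1", "h3k4me3"),
    ("uuid2", "rna_seq"),
]

votes = defaultdict(Counter)
for uuid, predicted in file_predictions:
    votes[uuid][predicted] += 1

# most_common(1) returns the class with the highest count for each uuid.
majority_prediction = {uuid: counts.most_common(1)[0][0] for uuid, counts in votes.items()}

print(majority_prediction)
```

The confusion matrix is then built from one row per uuid instead of one row per file.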
Prepare prediction data for confusion matrix.
```{python}
#| label: fig1d-data-prep
assay_split_dfs = split_results_handler.gather_split_results_across_methods(
results_dir=data_dir_100kb, label_category=ASSAY, only_NN=True
)
concat_assay_df = split_results_handler.concatenate_split_results(assay_split_dfs)["NN"]
df_with_meta = metadata_handler.join_metadata(concat_assay_df, metadata_v2) # type: ignore
if "Predicted class" not in df_with_meta.columns:
raise ValueError("`Predicted class` not in DataFrame")
classifier_name = "MLP"
min_pred_score = 0
majority = False
name = f"{classifier_name}_pred>{min_pred_score}"
logdir = base_fig_dir / "fig1_EpiAtlas_assay" / "fig1_supp_D-assay_c11_confusion_matrices"
if majority:
logdir = logdir / "per_uuid"
else:
logdir = logdir / "per_file"
logdir.mkdir(parents=True, exist_ok=True)
```
Graph.
```{python}
#| label: fig1d-create-cm
#| column: body-outset-left
create_confusion_matrix(
df=df_with_meta,
min_pred_score=min_pred_score,
majority=majority,
)
```
Fig. 1D: Confusion matrix aggregating the cross-validation folds (therefore showing all files) without applying a prediction score threshold. RNA-seq and WGBS data were both separated according to two protocols during initial training (but combined thereafter to nine assays).
### E - Mislabeled target assays
Fig. 1E: Genome browser representation showing, in black, the dataset swap between H3K4me3 and H3K27ac for IHECRE00001897 in metadata freeze v1.0, alongside typical correct datasets over a representative region.
## Supplementary Figure 1
More detailed performance of EpiClass Assay and Biospecimen classifiers.
### A,B - All classifiers metrics on EpiAtlas
The Fig. 1B data points are included in these two graphs (MLP data points).\
\
\
Define graphing function `plot_multiple_models_split_metrics`.
```{python}
#| label: supp-fig1ab-plot-function
def plot_multiple_models_split_metrics(
split_metrics: Dict[str, Dict[str, Dict[str, float]]],
label_category: str,
logdir: Path | None = None,
filename: str = "fig1_all_classifiers_metrics",
) -> None:
"""Render to box plots the metrics per classifier/models and split, each in its own subplot.
    Args:
        split_metrics: A dictionary containing metric scores for each classifier and split.
        label_category: The label category for the classification task.
        logdir: The directory to save the figure to. If None, the figure is only displayed.
        filename: Base name (without extension) for the saved figure files.
    Returns:
        None: Displays the figure and saves it to the logdir if provided.
    """
    metrics = ["Accuracy", "F1_macro", "AUC_micro", "AUC_macro"]
    # Fixed display order; every listed classifier must be present in each split.
    classifier_names = ["NN", "LR", "LGBM", "LinearSVC", "RF"]
    # Create subplots, one column for each metric
fig = make_subplots(
rows=1,
cols=len(metrics),
subplot_titles=metrics,
horizontal_spacing=0.075,
)
for i, metric in enumerate(metrics):
for classifier in classifier_names:
values = [split_metrics[split][classifier][metric] for split in split_metrics]
if classifier == "NN":
classifier = "MLP"
fig.add_trace(
go.Box(
y=values,
name=classifier,
line=dict(color="black", width=1.5),
marker=dict(size=3, color="black"),
boxmean=True,
boxpoints="all", # or "outliers" to show only outliers
pointpos=-1.4,
showlegend=False,
width=0.5,
hovertemplate="%{text}",
text=[
f"{split}: {value:.4f}"
for split, value in zip(split_metrics, values)
],
),
row=1,
col=i + 1,
)
fig.update_layout(
title_text=f"{label_category} classification",
boxmode="group",
**main_title_settings,
)
# Adjust y-axis
if label_category == ASSAY:
range_acc = [0.95, 1.001]
range_AUC = [0.992, 1.0001]
elif label_category == CELL_TYPE:
range_acc = [0.81, 1]
range_AUC = [0.96, 1]
else:
range_acc = [0.6, 1.001]
range_AUC = [0.9, 1.0001]
fig.update_layout(
yaxis=dict(range=range_acc),
yaxis2=dict(range=range_acc),
yaxis3=dict(range=range_AUC),
yaxis4=dict(range=range_AUC),
height=450,
)
fig.update_layout(margin=dict(l=20, r=20))
# Save figure
if logdir:
fig.write_image(logdir / f"{filename}.svg")
fig.write_image(logdir / f"{filename}.png")
fig.write_html(logdir / f"{filename}.html")
fig.show()
```
Graph.
```{python}
#| label: supp-fig1ab-run-plots
#| column: page-left
#| layout-nrow: 2
merge_assays = True
for label_category in [ASSAY, CELL_TYPE]:
all_split_dfs = split_results_handler.gather_split_results_across_methods(
results_dir=data_dir_100kb,
label_category=label_category,
only_NN=False,
)
if merge_assays and label_category == ASSAY:
for split_name, split_dfs in all_split_dfs.items():
for classifier_type, df in split_dfs.items():
split_dfs[classifier_type] = merge_similar_assays(df)
split_metrics = split_results_handler.compute_split_metrics(all_split_dfs)
plot_multiple_models_split_metrics(
split_metrics,
label_category=label_category,
)
```
Supplementary Figure 1A,B: Distribution of performance scores (accuracy, F1 as well as micro and macro AUROC) per training fold (dots) for each machine learning approach used for training on the Assay (A) and Biospecimen (B) metadata. Micro-averaging aggregates contributions from all classes (global true positive rate and false positive rate); macro-averaging averages the true positive rate from each class. Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range.\
\
\
**Going forward, all results are for MLP classifiers.**
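
The difference between micro- and macro-averaged AUROC can be illustrated on a toy three-class problem. The sketch below is pure Python with hypothetical probabilities; it computes AUROC as the probability that a positive outranks a negative, and takes the macro score as the mean of per-class AUROC values, which is one common definition and not necessarily the exact procedure used for the figures.

```python
from itertools import product


def binary_auc(labels, scores):
    """AUROC as the probability that a positive outranks a negative (ties count 0.5)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(positives, negatives)
    )
    return wins / (len(positives) * len(negatives))


# Hypothetical 3-class problem: true class index and predicted probabilities per sample.
y_true = [0, 0, 1, 1, 2, 2]
y_prob = [
    [0.7, 0.2, 0.1],
    [0.4, 0.5, 0.1],
    [0.2, 0.6, 0.2],
    [0.3, 0.4, 0.3],
    [0.1, 0.2, 0.7],
    [0.2, 0.3, 0.5],
]
n_classes = 3

# One-vs-rest indicator columns (the same layout a label binarizer produces).
one_hot = [[int(label == k) for k in range(n_classes)] for label in y_true]

# Macro: compute one AUROC per class, then take the unweighted mean.
per_class = [
    binary_auc([row[k] for row in one_hot], [row[k] for row in y_prob])
    for k in range(n_classes)
]
macro_auc = sum(per_class) / n_classes

# Micro: pool every (sample, class) pair into a single binary problem.
micro_auc = binary_auc(
    [v for row in one_hot for v in row],
    [p for row in y_prob for p in row],
)

print(f"macro={macro_auc:.3f}, micro={micro_auc:.3f}")
```

Macro-averaging gives every class the same weight regardless of its size, whereas micro-averaging is dominated by the most frequent classes, so the two can diverge on imbalanced metadata categories.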
### C - ROC curves
Define graphing function `plot_roc_curves`. Computes macro-average ROC curves manually.
```{python}
#| label: supp-fig1c-roc-function
def plot_roc_curves(
results_df: pd.DataFrame,
label_category: str,
logdir: Path | None = None,
name: str = "roc_curve",
title: str | None = None,
colors_dict: Dict | None = None, # Optional specific colors
verbose: bool = False,
) -> None:
"""
Generates and plots ROC curves for multi-class classification results using Plotly.
Calculates and plots individual class ROC curves, micro-average, and macro-average ROC curves.
Args:
        results_df (pd.DataFrame): DataFrame with true labels and prediction probabilities for each class.
                                   Must contain a 'True class' column with the ground-truth labels
                                   and one probability column named after each class.
        label_category (str): Name of the classified metadata category (e.g., ASSAY, CELL_TYPE).
                              Used only for the plot title and the output file names.
        logdir (Path | None): Directory to save the figure. If None, only displays the figure.
        name (str): Base name for saved files (e.g., "supp_fig1e").
        title (str | None): Title suffix for the plot. If None, a default title based on label_category is used.
        colors_dict (Dict | None): Optional dictionary mapping class names to colors. If None or a class
                                   is missing, default Plotly colors are used.
        verbose (bool): If True, print the list of classes used for the curves.
"""
df = results_df.copy()
true_label_col = "True class" # Assuming 'True class' holds the ground truth labels
if true_label_col not in df.columns:
raise ValueError(f"True label column '{true_label_col}' not found in DataFrame.")
classes = sorted(df[true_label_col].unique())
if verbose:
print(f"Using classes: {classes}")
n_classes = len(classes)
if n_classes < 2:
print(
f"Warning: Only {n_classes} class found after processing. Cannot generate ROC curve."
)
return
# Check if probability columns exist for all determined classes
missing_cols = [c for c in classes if c not in df.columns]
if missing_cols:
raise ValueError(f"Missing probability columns for classes: {missing_cols}")
# Binarize the true labels against the final set of classes
try:
y_true = label_binarize(df[true_label_col], classes=classes)
except ValueError as e:
raise ValueError(
f"Error binarizing labels for classes {classes}. Check if all labels in '{true_label_col}' are included in 'classes'."
) from e
if n_classes == 2 and y_true.shape[1] == 1:
# Adjust for binary case where label_binarize might return one column
y_true = np.hstack((1 - y_true, y_true)) # type: ignore
elif y_true.shape[1] != n_classes:
raise ValueError(
f"Binarized labels shape {y_true.shape} does not match number of classes {n_classes}"
)
# Get the predicted probabilities for each class
# Ensure columns are in the same order as 'classes'
y_score = df[classes].values
# --- Compute ROC curve and ROC area for each class ---
fpr = dict()
tpr = dict()
roc_auc = dict()
for i, class_name in enumerate(classes):
try:
fpr[class_name], tpr[class_name], _ = roc_curve(
y_true=y_true[:, i], y_score=y_score[:, i] # type: ignore
)
roc_auc[class_name] = auc(fpr[class_name], tpr[class_name])
except ValueError as e:
            raise ValueError(f"Could not compute ROC for class {class_name}.") from e
# --- Compute micro-average ROC curve and ROC area ---
try:
fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel()) # type: ignore
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
except ValueError as e:
raise ValueError("Could not compute micro-average ROC.") from e
# --- Compute macro-average ROC curve and ROC area ---
try:
# Aggregate all false positive rates
all_fpr = np.unique(
np.concatenate(
[fpr[class_name] for class_name in classes if class_name in fpr]
)
)
# Interpolate all ROC curves at these points
mean_tpr = np.zeros_like(all_fpr)
valid_classes_count = 0
for class_name in classes:
if class_name in fpr and class_name in tpr:
mean_tpr += np.interp(all_fpr, fpr[class_name], tpr[class_name])
valid_classes_count += 1
# Average it and compute AUC
if valid_classes_count > 0:
mean_tpr /= valid_classes_count
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
else:
raise ValueError("No valid classes found for macro averaging.")
except ValueError as e:
raise ValueError("Could not compute macro-average ROC.") from e
# --- Plot all ROC curves ---
fig = go.Figure()
# Plot diagonal line for reference
fig.add_shape(
type="line", line=dict(dash="dash", color="grey", width=1), x0=0, x1=1, y0=0, y1=1
)
# Define colors for plotting
color_cycle = px.colors.qualitative.Plotly # Default cycle
plot_colors = {}
for i, cls_name in enumerate(classes):
if colors_dict and cls_name in colors_dict:
plot_colors[cls_name] = colors_dict[cls_name]
else:
plot_colors[cls_name] = color_cycle[i % len(color_cycle)]
# Plot Micro-average ROC curve first (often plotted thicker/dashed)
fig.add_trace(
go.Scatter(
x=fpr["micro"],
y=tpr["micro"],
mode="lines",
name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.5f})',
line=dict(color="deeppink", width=3, dash="dash"),
hoverinfo="skip", # Less important for hover usually
)
)
# Plot Macro-average ROC curve
fig.add_trace(
go.Scatter(
x=fpr["macro"],
y=tpr["macro"],
mode="lines",
name=f'Macro-average ROC (AUC = {roc_auc["macro"]:.5f})',
line=dict(color="navy", width=3, dash="dash"),
hoverinfo="skip",
)
)
# Plot individual class ROC curves
for class_name in classes:
if class_name not in fpr or class_name not in tpr or class_name not in roc_auc:
continue # Skip if calculation failed
fig.add_trace(
go.Scatter(
x=fpr[class_name],
y=tpr[class_name],
mode="lines",
name=f"{class_name} (AUC = {roc_auc[class_name]:.5f})",
line=dict(width=1.5, color=plot_colors.get(class_name)),
hovertemplate=f"<b>{class_name}</b><br>FPR=%{{x:.5f}}<br>TPR=%{{y:.5f}}<extra></extra>", # Show class name and values on hover
)
)
# --- Update layout ---
base_title = f"ROC Curves<br>{label_category}"
plot_title = f"{base_title} - {title}" if title else base_title
title_settings=dict(
yanchor="top",
yref="paper",
y=0.97,
xanchor="center",
xref="paper",
x=0.5,
)
fig.update_layout(
title=title_settings,
title_text=plot_title,
xaxis_title="False Positive Rate (1 - Specificity)",
yaxis_title="True Positive Rate (Sensitivity)",
xaxis=dict(range=[0.0, 1.0], constrain="domain"), # Ensure axes range 0-1
yaxis=dict(
range=[0.0, 1.01], scaleanchor="x", scaleratio=1, constrain="domain"
), # Make it square-ish, slight top margin
width=800,
height=650,
hovermode="closest",
legend=dict(
traceorder="reversed", # Show averages first in legend
title="Classes & Averages",
font=dict(size=9),
itemsizing="constant",
y=0.8,
yref="paper",
),
margin=dict(l=60, r=30, t=0, b=0),
)
# --- Save figure if logdir is provided ---
if logdir:
logdir.mkdir(parents=True, exist_ok=True) # Ensure directory exists
filename_base = f"{name}_{label_category}_roc"
filepath_base = logdir / filename_base
fig.write_html(f"{filepath_base}.html")
fig.write_image(f"{filepath_base}.svg", width=800, height=750)
fig.write_image(f"{filepath_base}.png", width=800, height=750, scale=2)
print(f"Saved ROC curve plots for {label_category} to {logdir}")
print(f" -> {filename_base}.html / .svg / .png")
fig.show()
```
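The manual macro-average in `plot_roc_curves` can be illustrated on two hand-made curves. This is a minimal sketch, not part of the figure pipeline: the per-class `fpr`/`tpr` values are invented, and the hypothetical `trapezoid_area` helper stands in for `sklearn.metrics.auc`, which applies the same trapezoidal rule.

```python
import numpy as np

# Toy per-class ROC points (hypothetical values, not EpiATLAS results).
fpr = {"class_a": np.array([0.0, 0.2, 1.0]), "class_b": np.array([0.0, 0.5, 1.0])}
tpr = {"class_a": np.array([0.0, 0.9, 1.0]), "class_b": np.array([0.0, 0.6, 1.0])}


def trapezoid_area(x: np.ndarray, y: np.ndarray) -> float:
    """Area under a piecewise-linear curve (trapezoidal rule)."""
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))


# 1. Pool every false positive rate at which any class curve has a point.
all_fpr = np.unique(np.concatenate(list(fpr.values())))

# 2. Interpolate each class curve on that common grid, then average the TPRs.
mean_tpr = np.mean([np.interp(all_fpr, fpr[c], tpr[c]) for c in fpr], axis=0)

# 3. The macro-average AUC is the area under the averaged curve.
macro_auc = trapezoid_area(all_fpr, mean_tpr)
per_class_auc = {c: trapezoid_area(fpr[c], tpr[c]) for c in fpr}
print(f"macro AUC = {macro_auc:.3f}")
```

Because linear interpolation is exact on the union of the curves' breakpoints, the area under the averaged curve equals the unweighted mean of the per-class AUCs (here 0.85 and 0.55, giving 0.70).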
Prepare assay data for plotting.
```{python}
#| label: supp-fig1c-roc-assay-data
data_dir = (
mixed_data_dir
/ "hg38_100kb_all_none"
/ f"{ASSAY}_1l_3000n"
/ "11c"
/ "10fold-oversampling"
)
PathChecker.check_directory(data_dir)
dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1) # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
df = merge_similar_assays(concat_df_w_meta.copy())
```
Graph assay results.
```{python}
#| label: supp-fig1c-roc-assay-plot
plot_roc_curves(
results_df=df.copy(),
label_category=ASSAY,
title="Aggregated 10fold", # Title suffix
colors_dict=assay_colors,
verbose=False,
)
```
Prepare biospecimen data for plotting.
```{python}
#| label: supp-fig1c-roc-celltype-data
data_dir = (
mixed_data_dir
/ "hg38_100kb_all_none"
/ f"{CELL_TYPE}_1l_3000n"
/ "10fold-oversampling"
)
PathChecker.check_directory(data_dir)
dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1) # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
```
Graph biospecimen results.
```{python}
#| label: supp-fig1c-roc-celltype-plot
plot_roc_curves(
results_df=concat_df_w_meta,
label_category=CELL_TYPE,
title="Aggregated 10fold", # Title suffix
colors_dict=cell_type_colors,
verbose=False,
)
```
Supplementary Figure 1C: ROC curves from aggregated cross-validation results for the Assay and Biospecimen classifiers. Curves for each class are
computed in a one-vs-rest scheme.
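As an illustration of the one-vs-rest scheme, the sketch below binarizes a handful of labels and shows how the micro-average pools all classes into a single binary problem. The class subset and the scores are invented for the example; the figure itself relies on `label_binarize` and `roc_curve` from scikit-learn.

```python
import numpy as np

classes = np.array(["h3k27ac", "input", "rna_seq"])  # hypothetical class subset
true_labels = np.array(["input", "rna_seq", "h3k27ac", "input"])

# One column per class: 1 where the file belongs to that class, 0 otherwise.
y_true = (true_labels[:, None] == classes[None, :]).astype(int)

# Invented prediction scores, one row per file, columns ordered like `classes`.
y_score = np.array(
    [
        [0.1, 0.8, 0.1],
        [0.2, 0.1, 0.7],
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
    ]
)

# Per-class curve: column i of y_true against column i of y_score.
input_true, input_score = y_true[:, 1], y_score[:, 1]

# Micro-average: flatten both matrices into one pooled binary problem.
pooled_true, pooled_score = y_true.ravel(), y_score.ravel()
print(y_true)
print(pooled_true.shape, pooled_score.shape)
```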
### D - Alternative signal pre-processing
Define graphing function `create_blklst_graphs`.
```{python}
#| label: supp-fig1d-blklst-function
def create_blklst_graphs(
feature_set_metrics_dict: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
logdir: Path | None = None,
) -> list[go.Figure]:
"""Create boxplots for blacklisted related feature sets.
Args:
feature_set_metrics_dict (Dict[str, Dict[str, Dict[str, Dict[str, float]]]]): The dictionary containing all metrics for all blklst related feature sets.
format: {feature_set: {task_name: {split_name: metric_dict}}}
        logdir (Path | None): The directory to save the figures to. If None, the figures are not saved.

    Returns:
        A list with one figure per task.
"""
figs = []
# Assume names exist in all feature sets
task_names = list(feature_set_metrics_dict.values())[0].keys()
traces_names_dict = {
"hg38_100kb_all_none": "observed",
"hg38_100kb_all_none_0blklst": "0blklst",
"hg38_100kb_all_none_0blklst_winsorized": "0blklst_winsorized",
}
for task_name in task_names:
category_fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=False,
subplot_titles=["Accuracy", "F1-score (macro)"],
horizontal_spacing=0.1,
)
for feature_set_name, tasks_dicts in feature_set_metrics_dict.items():
task_dict = tasks_dicts[task_name]
trace_name = traces_names_dict[feature_set_name]
# Accuracy
metric = "Accuracy"
y_vals = [task_dict[split][metric] for split in task_dict] # type: ignore
hovertext = [
f"{split}: {metrics_dict[metric]:.4f}" # type: ignore
for split, metrics_dict in task_dict.items()
]
category_fig.add_trace(
go.Box(
y=y_vals,
name=trace_name,
boxmean=True,
boxpoints="all",
showlegend=False,
marker=dict(size=3, color="black"),
line=dict(width=1, color="black"),
hovertemplate="%{text}",
text=hovertext,
),
row=1,
col=1,
)
metric = "F1_macro"
y_vals = [task_dict[split][metric] for split in task_dict] # type: ignore
hovertext = [
f"{split}: {metrics_dict[metric]:.4f}" # type: ignore
for split, metrics_dict in task_dict.items()
]
category_fig.add_trace(
go.Box(
y=y_vals,
name=trace_name,
boxmean=True,
boxpoints="all",
showlegend=False,
marker=dict(size=3, color="black"),
line=dict(width=1, color="black"),
hovertemplate="%{text}",
text=hovertext,
),
row=1,
col=2,
)
category_fig.update_xaxes(
categoryorder="array",
categoryarray=list(traces_names_dict.values()),
)
category_fig.update_yaxes(range=[0.9, 1.001])
category_fig.update_layout(
title_text=task_name,
height=600,
width=500,
**main_title_settings
)
# Save figure
if logdir:
task_name = task_name.replace("_1l_3000n-10fold", "")
base_name = f"metrics_{task_name}"
category_fig.write_html(logdir / f"{base_name}.html")
category_fig.write_image(logdir / f"{base_name}.svg")
category_fig.write_image(logdir / f"{base_name}.png")
figs.append(category_fig)
return figs
```
Prepare paths.
```{python}
#| label: supp-fig1d-data-prep
include_sets = [
"hg38_100kb_all_none",
"hg38_100kb_all_none_0blklst",
"hg38_100kb_all_none_0blklst_winsorized",
]
results_folder_blklst = base_data_dir / "training_results" / "2023-01-epiatlas-freeze"
PathChecker.check_directory(results_folder_blklst)
```
Compute metrics.
```{python}
#| label: supp-fig1d-metrics
# Select 10-fold oversampling runs
# expected result shape: {feature_set: {task_name: {split_name: metrics_dict}}}
all_metrics_blklst: Dict[
str, Dict[str, Dict[str, Dict[str, float]]]
] = split_results_handler.obtain_all_feature_set_data(
return_type="metrics",
parent_folder=results_folder_blklst,
merge_assays=True,
include_categories=[ASSAY, CELL_TYPE],
include_sets=include_sets,
oversampled_only=False,
verbose=False,
) # type: ignore
```
Graph.
```{python}
#| label: supp-fig1dplot
#| layout-ncol: 2
#| column: screen-inset-left
figs = create_blklst_graphs(all_metrics_blklst)
figs[0].show()
figs[1].show()
```
Supplementary Figure 1D: Distribution of accuracy and F1-score per training fold (dots) for the Assay and Biospecimen classifiers after removing signal from blacklisted regions and applying winsorization of 0.1%. Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range.
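The two pre-processing variants compared here can be sketched on a synthetic signal vector. This is only an illustration under stated assumptions: the blacklist mask is hypothetical, and the 0.1% winsorization is assumed to cap the upper tail at the 99.9th percentile. The actual EpiClass pre-processing is not reproduced in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.gamma(shape=2.0, scale=1.0, size=10_000)  # synthetic per-bin signal
signal[:5] = 1_000.0  # a few extreme bins

# Hypothetical blacklist mask: True for bins overlapping a blacklisted region.
blacklisted = np.zeros(signal.size, dtype=bool)
blacklisted[:3] = True

# Step 1 ("0blklst"): set the signal of blacklisted bins to zero.
cleaned = np.where(blacklisted, 0.0, signal)

# Step 2 ("winsorized"): cap values above the 99.9th percentile.
# Assumption: the 0.1% applies to the upper tail only.
cap = np.quantile(cleaned, 0.999)
winsorized = np.minimum(cleaned, cap)
print(f"cap = {cap:.3f}, max after winsorization = {winsorized.max():.3f}")
```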
### E-G - Distribution of average prediction scores per assay
- E: Assay training 10-fold cross-validation
- F: Biospecimen 10-fold cross-validation
- G: Assay complete training (mixed tracks), predictions on imputed data (all pval)
\
Define graphing function `plot_prediction_scores_distribution`.
```{python}
#| label: supp-fig1efg-plot-function
#| echo: false
def plot_prediction_scores_distribution(
results_df: pd.DataFrame,
merge_assay_pairs: bool = True,
logdir: Path | None = None,
name: str = "prediction_score_distribution",
group_by_column: str = "True class",
min_y: float = 0.7,
use_aggregate_vote: bool = True,
title: str | None = None,
) -> None:
"""
Creates a Plotly figure with violin plots and associated scatter plots for each group.
Supports both aggregated and non-aggregated data visualization with enhanced styling.
Args:
results_df (pd.DataFrame): DataFrame containing prediction results and metadata
merge_assay_pairs (bool): Whether to merge similar assays (mrna/rna, wgbs-pbat/wgbs-standard)
logdir (Path | None): Directory to save figures. If None, only displays the figure
name (str): Base name for saved files
group_by_column (str): Column name to use for grouping traces
min_y (float): Minimum y-axis value
use_aggregate_vote (bool): If True, aggregate by EpiRR. If False, use individual predictions
title (str | None): Additional title text to append
"""
fig = go.Figure()
if merge_assay_pairs:
try:
results_df = merge_similar_assays(results_df)
except ValueError as e:
print(f"Skipping assay merging: {e}")
# Group ordering
if group_by_column == ASSAY and merge_assay_pairs:
group_labels = ASSAY_ORDER
else:
group_labels = sorted(set(results_df[group_by_column].unique()))
group_index = {label: i for i, label in enumerate(group_labels)}
# Colors for each group
if group_by_column == ASSAY:
colors = assay_colors
else:
grey = "rgba(237, 231, 225, 1)"
colors = {label: grey for label in group_labels}
for label in group_labels:
sub_df = results_df[results_df[group_by_column] == label]
if use_aggregate_vote:
# Aggregate by EpiRR with majority voting
groupby = sub_df.groupby(["EpiRR", "Predicted class", "True class"])[
"Max pred"
].aggregate(["size", "mean"])
groupby = groupby.reset_index().sort_values(
["EpiRR", "size"], ascending=[True, False]
)
groupby = groupby.drop_duplicates(subset="EpiRR", keep="first")
assert groupby["EpiRR"].is_unique
mean_pred = groupby["mean"]
else:
# Use individual predictions
mean_pred = sub_df["Max pred"]
# Add violin plot
fig.add_trace(
go.Violin(
x=[group_index[label]] * len(mean_pred),
y=mean_pred,
name=label,
spanmode="hard",
box_visible=True,
meanline_visible=True,
points="all",
marker=dict(size=1, color="black"),
fillcolor=colors[label],
line=dict(width=0.5, color="white"),
showlegend=True,
)
)
# Update layout
title_text = "Prediction Score Distribution"
if use_aggregate_vote:
title_text += " (EpiRR majority vote)"
if title:
title_text += f" - {title}"
fig.update_layout(
title=main_title_settings["title"],
title_text=title_text,
yaxis_title=(
"Prediction Score"
if not use_aggregate_vote
else "Avg. Prediction Score (majority class)"
),
xaxis_title=group_by_column,
yaxis_range=[min_y, 1.001],
xaxis=dict(
tickvals=list(group_index.values()),
ticktext=list(group_index.keys()),
),
legend=dict(
title_text="Legend",
itemsizing="constant",
orientation="h",
),
height=650,
width=900,
margin=dict(t=60, r=10, l=10)
)
# Save figure if logdir is provided
if logdir:
filename = f"{name}_epirr" if use_aggregate_vote else name
fig.write_html(logdir / f"{filename}.html")
fig.write_image(logdir / f"{filename}.svg")
fig.write_image(logdir / f"{filename}.png")
fig.show()
```
Supplementary Figure 1E-G: Distribution of average prediction score per file (dots) for the majority-vote class (up to three track type files) (E, F) or individual file (G), from the MLP approach for the Assay (E, G) and Biospecimen classifiers (F), using aggregated cross-validation results from observed data (E, F) or results from the classifier trained on all observed data and applied to imputed data from EpiATLAS (G). Dashed lines represent means, solid lines the medians, boxes the quartiles, and whiskers the farthest points within 1.5× the interquartile range, with a violin representation on top.
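The majority-vote aggregation used for panels E and F can be shown on a toy table. This is a minimal sketch with invented EpiRR identifiers and scores; it mirrors the group, sort, and de-duplicate steps of `plot_prediction_scores_distribution` but leaves out the `True class` column for brevity.

```python
import pandas as pd

# Toy predictions: three track files for one epigenome, one for another.
preds = pd.DataFrame(
    {
        "EpiRR": ["epirr_1", "epirr_1", "epirr_1", "epirr_2"],
        "Predicted class": ["h3k27ac", "h3k27ac", "h3k4me1", "input"],
        "Max pred": [0.90, 0.80, 0.95, 0.70],
    }
)

# Count votes and average the score for each (EpiRR, predicted class) pair.
votes = (
    preds.groupby(["EpiRR", "Predicted class"])["Max pred"]
    .agg(["size", "mean"])
    .reset_index()
)

# Keep the most frequent predicted class for each EpiRR.
majority = (
    votes.sort_values(["EpiRR", "size"], ascending=[True, False])
    .drop_duplicates(subset="EpiRR", keep="first")
    .set_index("EpiRR")
)
print(majority)
```

Note that the plotted value is the mean score of the winning class only, so a single confident vote for another class (0.95 here) does not affect it. When two classes receive the same number of votes, the sort order alone decides which one is kept.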
#### E - Assay
Gather prediction scores.
```{python}
#| label: supp-fig1e-data
data_dir = (
mixed_data_dir
/ "hg38_100kb_all_none"
/ f"{ASSAY}_1l_3000n"
/ "11c"
/ "10fold-oversampling"
)
PathChecker.check_directory(data_dir)
dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1) # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
```
Graph.
```{python}
#| label: supp-fig1e-plot
#| column: page-left
plot_prediction_scores_distribution(
results_df=concat_df_w_meta,
group_by_column=ASSAY,
merge_assay_pairs=True,
min_y=0.7,
title="11 classes assay training<br>Prediction scores for 10-fold cross-validation",
)
```
#### F - Biospecimen
Gather prediction scores.
```{python}
#| label: supp-fig1f-data
data_dir = data_dir_100kb / f"{CELL_TYPE}_1l_3000n" / "10fold-oversampling"
PathChecker.check_directory(data_dir)
dfs = split_results_handler.read_split_results(data_dir)
concat_df: pd.DataFrame = split_results_handler.concatenate_split_results(dfs, depth=1) # type: ignore
concat_df = split_results_handler.add_max_pred(concat_df)
concat_df_w_meta = metadata_handler.join_metadata(concat_df, metadata_v2)
concat_df_w_meta.replace({ASSAY: ASSAY_MERGE_DICT}, inplace=True)
```
Graph.
```{python}
#| label: supp-fig1f-plot
#| column: page-left
plot_prediction_scores_distribution(
results_df=concat_df_w_meta,
group_by_column=ASSAY,
min_y=0,
title="Biospecimen training<br>Prediction scores for 10-fold cross-validation",
)
```
#### G - Assay imputed
Gather imputed signal metadata.
```{python}
#| label: supp-fig1g-imputed-metadata
metadata_path = (
paper_dir
/ "data"
/ "metadata"
/ "epiatlas"
/ "imputed"
/ "hg38_epiatlas_imputed_pval_chip_2024-02.json"
)
metadata_imputed: pd.DataFrame = metadata_handler.load_any_metadata(metadata_path, as_dataframe=True) # type: ignore
```
Gather prediction scores.
```{python}
#| label: supp-fig1g-imputed-data
data_dir = (
gen_data_dir
/ "hg38_100kb_all_none"
/ f"{ASSAY}_1l_3000n"
/ "11c"
/ "complete_no_valid_oversample"
/ "predictions"
/ "epiatlas_imputed"
/ "ChIP"
)
PathChecker.check_directory(data_dir)
df_pred = pd.read_csv(
data_dir / "complete_no_valid_oversample_prediction.csv",
index_col=0,
)
```
Prepare dataframe for graphing.
```{python}
#| label: supp-fig1g-add-max-pred
assay_classes = list(metadata_v2_df[ASSAY].unique())
df_pred = split_results_handler.add_max_pred(df_pred, expected_classes=assay_classes)
augmented_df = pd.merge(df_pred, metadata_imputed, left_index=True, right_on="md5sum")
augmented_df["True class"] = augmented_df[ASSAY]
print("Number of files per assay:")
print(augmented_df["True class"].value_counts(dropna=False).to_string())
```
\
\
Graph.
```{python}
#| label: supp-fig1g-plot
#| column: page-left
plot_prediction_scores_distribution(
results_df=augmented_df,
group_by_column=ASSAY,
merge_assay_pairs=True,
min_y=0.79,
use_aggregate_vote=False,
title="Complete 11c assay classifier<br>inference on imputed data",
)
```
### H,I - Prediction score thresholds
For the code that produced the figures, see `src/python/epiclass/utils/notebooks/paper/confidence_threshold.ipynb` ([permalink](https://github.com/rabyj/EpiClass/blob/e87306bccecd3e35d9ffdbb5ddb29fcaa0800a35/src/python/epiclass/utils/notebooks/paper/confidence_threshold.ipynb)).
<!-- figures are supposed to be 700x600, but need some extra margin to not create scroll bar -->
```{=html}
<iframe width="710" height="615" src="../resources/threshold_impact_graph_global_assay_epiclass_11c-11_classes_cdn.html"></iframe>
<iframe width="710" height="615" src="../resources/threshold_impact_graph_global_harmonized_sample_ontology_intermediate-16_classes_cdn.html"></iframe>
```
Supplementary Figure 1H,I: Distribution of aggregated accuracy, F1-score and corresponding file subset size across
varying prediction score thresholds, based on pooled predictions from all cross-validation folds
for the Assay (H) and Biospecimen (I) classifiers.
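The idea behind these threshold curves can be sketched as follows. This is not the code from the linked notebook: the data are invented, only accuracy and subset size are computed (F1 is omitted), and keeping files whose score is greater than or equal to the threshold is an assumption.

```python
import pandas as pd

# Toy pooled predictions (invented values).
pooled = pd.DataFrame(
    {
        "Max pred": [0.95, 0.90, 0.60, 0.55],
        "correct": [True, True, False, True],
    }
)

rows = []
for threshold in [0.0, 0.7]:
    # Assumption: a file is kept when its score is >= the threshold.
    kept = pooled[pooled["Max pred"] >= threshold]
    rows.append(
        {
            "threshold": threshold,
            "accuracy": kept["correct"].mean(),
            "subset_fraction": len(kept) / len(pooled),
        }
    )
threshold_impact = pd.DataFrame(rows)
print(threshold_impact)
```

In this toy case, raising the threshold from 0 to 0.7 increases accuracy from 0.75 to 1.0 while halving the number of retained files, which is the trade-off the two panels display.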
## Supplementary Figure 2
Performance of EpiClass Assay and Biospecimen classifiers evaluated per training fold across various bin size resolutions and genomic feature sets.
### A,B
Graph Assay and Biospecimen metrics per cross-validation fold for each feature set, for reference (no per-assay breakdown).
```{python}
#| label: supp-fig2ab-plot
#| column: page-left
#| layout-nrow: 2
metrics_supp2 = {name: all_metrics[name] for name in feature_sets_14}
graph_feature_set_metrics(
all_metrics=metrics_supp2, # type: ignore
input_sizes=input_sizes,
boxpoints="all",
width=900,
height=600,
)
```
Supplementary Figure 2A,B: Distribution of accuracy over all files for the Assay (A) or Biospecimen (B) classifier.
### C,D
Define `prepare_metric_sets_per_assay`, a function that computes metrics per assay.
```{python}
#| label: supp-fig2rest-prep-function
def prepare_metric_sets_per_assay(
all_results: Dict[str, Dict[str, Dict[str, pd.DataFrame]]], verbose: bool = False
) -> Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]:
"""Prepare metric sets per assay.
Args:
all_results (Dict[str, Dict[str, Dict[str, pd.DataFrame]]]): A dictionary containing all results for all feature sets.
Returns:
        Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]: A dictionary containing all metrics per assay for all feature sets.
Format: {assay: {feature_set: {task_name: {split_name: metric_dict}}}}
"""
if verbose:
print("Loading metadata.")
metadata = metadata_handler.load_metadata("v2")
metadata.convert_classes(ASSAY, ASSAY_MERGE_DICT)
md5_per_assay = metadata.md5_per_class(ASSAY)
md5_per_assay = {k: set(v) for k, v in md5_per_assay.items()}
if verbose:
print("Getting results per assay.")
results_per_assay = {}
for assay_label in ASSAY_ORDER:
if verbose:
print(assay_label)
results_per_assay[assay_label] = {}
for feature_set, task_dict in all_results.items():
if verbose:
print(feature_set)
results_per_assay[assay_label][feature_set] = {}
for task_name, split_dict in task_dict.items():
if verbose:
print(task_name)
results_per_assay[assay_label][feature_set][task_name] = {}
# Only keep the relevant assay
for split_name, split_df in split_dict.items():
if verbose:
print(split_name)
assay_df = split_df[split_df.index.isin(md5_per_assay[assay_label])]
results_per_assay[assay_label][feature_set][task_name][
split_name
] = assay_df
if verbose:
print("Finished getting results per assay. Now computing metrics.")
metrics_per_assay = {}
for assay_label in ASSAY_ORDER:
if verbose:
print(assay_label)
metrics_per_assay[assay_label] = {}
for feature_set, task_dict in results_per_assay[assay_label].items():
if verbose:
print(feature_set)
assay_metrics = split_results_handler.compute_split_metrics(
task_dict, concat_first_level=True
)
inverted_dict = split_results_handler.invert_metrics_dict(assay_metrics)
metrics_per_assay[assay_label][feature_set] = inverted_dict
return metrics_per_assay
```
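The per-assay subsetting step above reduces to filtering each split DataFrame by md5sum before computing metrics. A minimal sketch with invented identifiers and labels follows; only accuracy is computed, whereas the notebook delegates the full metric set to `split_results_handler.compute_split_metrics`.

```python
import pandas as pd

# Toy split results indexed by md5sum (hypothetical identifiers and labels).
split_df = pd.DataFrame(
    {
        "True class": ["T cell", "B cell", "T cell", "B cell"],
        "Predicted class": ["T cell", "B cell", "B cell", "B cell"],
    },
    index=["md5_a", "md5_b", "md5_c", "md5_d"],
)
md5_per_assay = {"h3k27ac": {"md5_a", "md5_b"}, "rna_seq": {"md5_c", "md5_d"}}

accuracy_per_assay = {}
for assay_label, md5s in md5_per_assay.items():
    # Keep only the files of this assay, then score the biospecimen predictions.
    assay_df = split_df[split_df.index.isin(md5s)]
    accuracy_per_assay[assay_label] = float(
        (assay_df["True class"] == assay_df["Predicted class"]).mean()
    )
print(accuracy_per_assay)
```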
Define graphing function `graph_feature_set_metrics_per_assay`.
```{python}
#| label: supp-fig2rest-graph-function
def graph_feature_set_metrics_per_assay(
all_metrics_per_assay: Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]],
input_sizes: Dict[str, int],
logdir: Path | None = None,
sort_by_input_size: bool = False,
name: str | None = None,
y_range: Tuple[float, float] | None = None,
boxpoints: str = "outliers",
) -> None:
"""Graph the metrics for all feature sets, per assay, with separate plots for accuracy and F1-score.
Args:
all_metrics_per_assay (Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, float]]]]]): A dictionary containing all metrics per assay for all feature sets.
Format: {assay: {feature_set: {task_name: {split_name: metric_dict}}}}
input_sizes (Dict[str, int]): A dictionary containing the input sizes for all feature sets.
logdir (Path): The directory where the figures will be saved. If None, the figures will only be displayed.
sort_by_input_size (bool): Whether to sort the feature sets by input size.
name (str|None): The name of the figure.
y_range (Tuple[float, float]|None): The y-axis range for the plots.
boxpoints (str): The type of points to display in the box plots. Defaults to "outliers".
"""
valid_boxpoints = ["all", "outliers"]
if boxpoints not in valid_boxpoints:
raise ValueError(f"Invalid boxpoints value. Choose from {valid_boxpoints}.")
fig_assay_order = [
"rna_seq",
"h3k27ac",
"h3k4me1",
"h3k4me3",
"h3k36me3",
"h3k27me3",
"h3k9me3",
"input",
"wgbs",
]
reference_assay = next(iter(all_metrics_per_assay))
reference_feature_set = next(iter(all_metrics_per_assay[reference_assay]))
metadata_categories = list(
all_metrics_per_assay[reference_assay][reference_feature_set].keys()
)
    for category in metadata_categories:
for metric, metric_name in [
("Accuracy", "Accuracy"),
("F1_macro", "F1-score (macro)"),
]:
fig = go.Figure()
feature_sets = list(all_metrics_per_assay[reference_assay].keys())
unique_feature_sets = set(feature_sets)
for assay in fig_assay_order:
if set(all_metrics_per_assay[assay].keys()) != unique_feature_sets:
raise ValueError("Different feature sets through assays.")
feature_set_order = feature_sets
if sort_by_input_size:
feature_set_order = sorted(
feature_set_order, key=lambda x: input_sizes[x]
)
# Adjust spacing so each assay group has dedicated space based on the number of feature sets
spacing_multiplier = (
1.1 # Increase this multiplier if needed to add more spacing
)
x_positions = {
assay: i * len(feature_set_order) * spacing_multiplier
for i, assay in enumerate(fig_assay_order)
}
for i, feature_set_name in enumerate(feature_set_order):
resolution = (
feature_set_name.replace("_none", "")
.replace("hg38_", "")
.split("_")[0]
)
color = resolution_colors[resolution]
display_name = feature_set_name.replace("_none", "").replace("hg38_", "")
for assay in fig_assay_order:
if feature_set_name not in all_metrics_per_assay[assay]:
continue
tasks_dicts = all_metrics_per_assay[assay][feature_set_name]
if feature_set_name not in input_sizes:
print(f"Skipping {feature_set_name}, no input size found.")
continue
task_name = category
if "split" in task_name:
raise ValueError("Split in task name. Wrong metrics dict.")
try:
task_dict = tasks_dicts[task_name]
except KeyError:
print(
f"Skipping {feature_set_name}, {task_name} for assay {assay}"
)
continue
y_vals = [task_dict[split][metric] for split in task_dict]
hovertext = [
f"{assay} - {display_name} - {split}: {metrics_dict[metric]:.4f}"
for split, metrics_dict in task_dict.items()
]
x_position = x_positions[assay] + i
fig.add_trace(
go.Box(
x=[x_position] * len(y_vals),
y=y_vals,
name=f"{assay}|{display_name}",
boxmean=True,
boxpoints=boxpoints,
marker=dict(size=3, color="black"),
line=dict(width=1, color="black"),
fillcolor=color,
hovertemplate="%{text}",
text=hovertext,
showlegend=False,
legendgroup=display_name,
)
)
# separate box groups
fig.add_vline(
x=x_positions[assay] - 1, line_width=1, line_color="black"
)
# Add dummy traces for the legend
for feature_set_name in feature_set_order:
resolution = (
feature_set_name.replace("_none", "")
.replace("hg38_", "")
.split("_")[0]
)
color = resolution_colors[resolution]
display_name = feature_set_name.replace("_none", "").replace("hg38_", "")
display_name = re.sub(r"\_[\dmkb]+\_coord", "", display_name)
fig.add_trace(
go.Scatter(
name=display_name,
x=[None],
y=[None],
mode="markers",
marker=dict(size=10, color=color),
showlegend=True,
legendgroup=display_name,
)
)
title = f"{category} - {metric_name} (per assay)"
if name is not None:
title += f" - {name}"
fig.update_layout(
width=1250,
height=900,
title_text=title,
xaxis_title="Assay",
yaxis_title=metric_name,
**main_title_settings
)
# Create x-axis labels
fig.update_xaxes(
tickmode="array",
tickvals=[
x_positions[assay] + len(feature_set_order) / 2
for assay in fig_assay_order
],
ticktext=list(x_positions.keys()),
title="Assay",
)
fig.update_layout(
legend=dict(
title="Feature Sets", itemsizing="constant", traceorder="normal"
)
)
if y_range:
fig.update_yaxes(range=y_range)
if logdir:
base_name = f"feature_set_metrics_{category}_{metric}_per_assay"
if name is not None:
base_name = base_name + f"_{name}"
fig.write_html(logdir / f"{base_name}.html")
fig.write_image(logdir / f"{base_name}.svg")
fig.write_image(logdir / f"{base_name}.png")
fig.show()
```
Get prediction scores for multiple feature sets.
```{python}
#| label: supp-fig2rest-get-data
set_selection_name = "feature_sets_14"
all_results = split_results_handler.obtain_all_feature_set_data(
parent_folder=mixed_data_dir,
merge_assays=True,
return_type="split_results",
include_categories=[CELL_TYPE],
include_sets=metric_orders_map[set_selection_name],
exclude_names=["16ct", "27ct", "7c", "chip-seq-only"],
)
```
Compute metrics per assay.
```{python}
#| label: supp-fig2rest-prep-metrics
metrics_per_assay = prepare_metric_sets_per_assay(all_results) # type: ignore
```
Reorder feature sets for graphing.
```{python}
#| label: supp-fig2rest-reorder
# Reorder the feature sets of each assay to match the plotting order
feature_set_order = metric_orders_map[set_selection_name]
for assay in list(metrics_per_assay):
    metrics_per_assay[assay] = {
        feature_set_name: metrics_per_assay[assay][feature_set_name]
        for feature_set_name in feature_set_order
    }
```
Graph.
```{python}
#| label: supp-fig2rest-plot
#| column: screen-inset-left
#| layout-nrow: 2
graph_feature_set_metrics_per_assay(
all_metrics_per_assay=metrics_per_assay, # type: ignore
input_sizes=input_sizes,
boxpoints="all",
sort_by_input_size=False,
y_range=(0.1, 1.01)
)
```
Supplementary Figure 2C,D: Distribution of accuracy calculated per assay for the Biospecimen classifier. Bin sizes include 10 Mb, 1 Mb, 100 kb, and 10 kb, corresponding to 315, 3,044, 30,321, and 303,114 non-overlapping regions covering the whole genome, respectively. Various numbers of random 100 kb, 10 kb and 1 kb regions were also used. Gene-based features include 19,864 gene regions, while cis-regulatory elements and methylation regions comprise 30,320 and 303,114 regions, respectively. Dashed lines represent means, solid lines the medians, boxes the quartiles, whiskers the farthest points within 1.5× the interquartile range, and dots are outliers.
## Supplementary Figure 3
Mislabeled datasets identified with EpiClass.
{width=.column-body}
Supplementary Figure 3: Genome browser representation of the eight originally mislabeled EpiATLAS datasets identified by EpiClass in metadata freeze v1.0 and discarded in subsequent metadata freezes (purple), along with representative correct datasets. The observed tracks are shown as positive signal, while imputed tracks (where available) are shown as negative signal.
## Supplementary Figure 4
Examples of low-quality datasets identified using EpiClass.
{width=.column-body}
Supplementary Figure 4: Genome browser representation of some of the low-quality EpiATLAS datasets identified by EpiClass in metadata freeze v1.0 and discarded in subsequent metadata freezes (purple), along with good-quality datasets from the same biospecimen. The observed tracks are shown as positive signal, while imputed tracks (where available) are shown as negative signal.